Identifying model bias¶

In this second part we will use explanation methods to identify a faulty classifier that was trained on biased data. Specifically, each image contains an artifact whose color is related to the class of the image. A model trained with such images will likely learn to disregard the image content entirely and only focus on the artifact to make a prediction. You will use one of the explanation methods implemented in the first part to spot the issue.

Altough in this example the bias was introduced artificially, it's not uncommon to see this kind of telling artifacts in real-world datasets. For example, in a dataset of X-ray scans, one might find identifiers along the edge or marks left by doctors that could hinder the learning of a model.

Setup¶

In [1]:
!pip install "jax[cuda]" -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html'

!pip install \
  flax optax \
  'git+https://github.com/n2cholas/jax-resnet.git' \
  tensorflow-datasets \
  better_exceptions
Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Requirement already satisfied: jax[cuda] in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (0.6.2)
Requirement already satisfied: jaxlib<=0.6.2,>=0.6.2 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from jax[cuda]) (0.6.2)
Requirement already satisfied: ml_dtypes>=0.5.0 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from jax[cuda]) (0.5.3)
Requirement already satisfied: numpy>=1.26 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from jax[cuda]) (1.26.4)
Requirement already satisfied: opt_einsum in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from jax[cuda]) (3.4.0)
Requirement already satisfied: scipy>=1.12 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from jax[cuda]) (1.15.3)
INFO: pip is looking at multiple versions of jax[cuda] to determine which version is compatible with other requirements. This could take a while.
Collecting jax[cuda]
  Using cached jax-0.6.1-py3-none-any.whl.metadata (13 kB)
Collecting jaxlib<=0.6.1,>=0.6.1 (from jax[cuda])
  Using cached jaxlib-0.6.1-cp310-cp310-macosx_11_0_arm64.whl.metadata (1.2 kB)
Collecting jax[cuda]
  Using cached jax-0.6.0-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib<=0.6.0,>=0.6.0 (from jax[cuda])
  Using cached jaxlib-0.6.0-cp310-cp310-macosx_11_0_arm64.whl.metadata (1.2 kB)
Collecting jax[cuda]
  Using cached jax-0.5.3-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib<=0.5.3,>=0.5.3 (from jax[cuda])
  Using cached jaxlib-0.5.3-cp310-cp310-macosx_11_0_arm64.whl.metadata (1.2 kB)
Collecting jax[cuda]
  Using cached jax-0.5.2-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib<=0.5.2,>=0.5.1 (from jax[cuda])
  Using cached jaxlib-0.5.1-cp310-cp310-macosx_11_0_arm64.whl.metadata (978 bytes)
Collecting jax[cuda]
  Using cached jax-0.5.1-py3-none-any.whl.metadata (22 kB)
  Using cached jax-0.5.0-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib<=0.5.0,>=0.5.0 (from jax[cuda])
  Using cached jaxlib-0.5.0-cp310-cp310-macosx_11_0_arm64.whl.metadata (978 bytes)
Collecting jax[cuda]
  Using cached jax-0.4.38-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib<=0.4.38,>=0.4.38 (from jax[cuda])
  Using cached jaxlib-0.4.38-cp310-cp310-macosx_11_0_arm64.whl.metadata (1.0 kB)
INFO: pip is still looking at multiple versions of jax[cuda] to determine which version is compatible with other requirements. This could take a while.
Collecting jax[cuda]
  Using cached jax-0.4.37-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib<=0.4.37,>=0.4.36 (from jax[cuda])
  Using cached jaxlib-0.4.36-cp310-cp310-macosx_11_0_arm64.whl.metadata (1.0 kB)
Collecting jax[cuda]
  Using cached jax-0.4.36-py3-none-any.whl.metadata (22 kB)
  Using cached jax-0.4.35-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib<=0.4.35,>=0.4.34 (from jax[cuda])
  Using cached jaxlib-0.4.35-cp310-cp310-macosx_11_0_arm64.whl.metadata (983 bytes)
  Using cached jaxlib-0.4.34-cp310-cp310-macosx_11_0_arm64.whl.metadata (983 bytes)
Collecting jax[cuda]
  Using cached jax-0.4.34-py3-none-any.whl.metadata (22 kB)
  Using cached jax-0.4.33-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib<=0.4.33,>=0.4.33 (from jax[cuda])
  Using cached jaxlib-0.4.33-cp310-cp310-macosx_11_0_arm64.whl.metadata (983 bytes)
INFO: This is taking longer than usual. You might need to provide the dependency resolver with stricter constraints to reduce runtime. See https://pip.pypa.io/warnings/backtracking for guidance. If you want to abort this run, press Ctrl + C.
Collecting jax[cuda]
  Using cached jax-0.4.31-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib<=0.4.31,>=0.4.30 (from jax[cuda])
  Using cached jaxlib-0.4.31-cp310-cp310-macosx_11_0_arm64.whl.metadata (983 bytes)
Collecting jax[cuda]
  Using cached jax-0.4.30-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib<=0.4.30,>=0.4.27 (from jax[cuda])
  Using cached jaxlib-0.4.30-cp310-cp310-macosx_11_0_arm64.whl.metadata (1.0 kB)
Collecting jax[cuda]
  Using cached jax-0.4.29-py3-none-any.whl.metadata (23 kB)
  Using cached jax-0.4.28-py3-none-any.whl.metadata (23 kB)
  Using cached jax-0.4.27-py3-none-any.whl.metadata (23 kB)
  Using cached jax-0.4.26-py3-none-any.whl.metadata (23 kB)
  Using cached jax-0.4.25-py3-none-any.whl.metadata (24 kB)
  Using cached jax-0.4.24-py3-none-any.whl.metadata (24 kB)
  Using cached jax-0.4.23-py3-none-any.whl.metadata (24 kB)
  Using cached jax-0.4.22-py3-none-any.whl.metadata (24 kB)
  Using cached jax-0.4.21-py3-none-any.whl.metadata (23 kB)
  Using cached jax-0.4.20-py3-none-any.whl.metadata (23 kB)
  Using cached jax-0.4.19-py3-none-any.whl.metadata (23 kB)
  Using cached jax-0.4.18-py3-none-any.whl.metadata (23 kB)
  Using cached jax-0.4.17-py3-none-any.whl.metadata (23 kB)
  Using cached jax-0.4.16-py3-none-any.whl.metadata (29 kB)
  Using cached jax-0.4.14.tar.gz (1.3 MB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
  Using cached jax-0.4.13.tar.gz (1.3 MB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
  Using cached jax-0.4.12.tar.gz (1.3 MB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
  Using cached jax-0.4.11.tar.gz (1.3 MB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
  Using cached jax-0.4.10.tar.gz (1.3 MB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
  Using cached jax-0.4.9.tar.gz (1.3 MB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
  Using cached jax-0.4.8.tar.gz (1.2 MB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
  Using cached jax-0.4.7.tar.gz (1.2 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.4.6.tar.gz (1.2 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.4.5.tar.gz (1.2 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.4.4.tar.gz (1.2 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.4.3.tar.gz (1.2 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.4.2.tar.gz (1.2 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.4.1.tar.gz (1.2 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.25.tar.gz (1.1 MB)
  Preparing metadata (setup.py) ... done
Requirement already satisfied: typing_extensions in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from jax[cuda]) (4.15.0)
  Using cached jax-0.3.24.tar.gz (1.1 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.23.tar.gz (1.1 MB)
  Preparing metadata (setup.py) ... done
Requirement already satisfied: absl-py in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from jax[cuda]) (2.3.1)
Requirement already satisfied: etils[epath] in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from jax[cuda]) (1.13.0)
  Using cached jax-0.3.22.tar.gz (1.1 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.21.tar.gz (1.1 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.20.tar.gz (1.1 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.19.tar.gz (1.1 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.17.tar.gz (1.1 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.16.tar.gz (1.0 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.15.tar.gz (1.0 MB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.14.tar.gz (990 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.13.tar.gz (951 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.12.tar.gz (947 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.11.tar.gz (947 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.10.tar.gz (939 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.9.tar.gz (937 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.8.tar.gz (935 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.7.tar.gz (944 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.6.tar.gz (936 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.5.tar.gz (946 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.4.tar.gz (924 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.3.tar.gz (924 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.2.tar.gz (926 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.1.tar.gz (912 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.3.0.tar.gz (896 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.2.28.tar.gz (887 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.2.27.tar.gz (873 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.2.26.tar.gz (850 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.2.25.tar.gz (786 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.2.24.tar.gz (786 kB)
  Preparing metadata (setup.py) ... done
  Using cached jax-0.2.22-py3-none-any.whl
WARNING: jax 0.2.22 does not provide the extra 'cuda'
Installing collected packages: jax
  Attempting uninstall: jax
    Found existing installation: jax 0.6.2
    Uninstalling jax-0.6.2:
      Successfully uninstalled jax-0.6.2
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
chex 0.1.90 requires jax>=0.4.27, but you have jax 0.2.22 which is incompatible.
flax 0.10.7 requires jax>=0.6.0, but you have jax 0.2.22 which is incompatible.
optax 0.2.6 requires jax>=0.5.3, but you have jax 0.2.22 which is incompatible.
orbax-checkpoint 0.11.25 requires jax>=0.6.0, but you have jax 0.2.22 which is incompatible.
Successfully installed jax-0.2.22
Collecting git+https://github.com/n2cholas/jax-resnet.git
  Cloning https://github.com/n2cholas/jax-resnet.git to /private/var/folders/mj/bwns80rs3psccqv0gz4hh8zc0000gn/T/pip-req-build-8vvkz9sb
  Running command git clone --filter=blob:none --quiet https://github.com/n2cholas/jax-resnet.git /private/var/folders/mj/bwns80rs3psccqv0gz4hh8zc0000gn/T/pip-req-build-8vvkz9sb
  Resolved https://github.com/n2cholas/jax-resnet.git to commit 5b00735aa0a68ec239af4a728ad4a596c1b551f6
  Preparing metadata (setup.py) ... done
Requirement already satisfied: flax in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (0.10.7)
Requirement already satisfied: optax in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (0.2.6)
Requirement already satisfied: tensorflow-datasets in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (4.9.9)
Requirement already satisfied: better_exceptions in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (0.3.3)
Requirement already satisfied: jax in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from jax-resnet==0.0.4) (0.2.22)
Requirement already satisfied: jaxlib in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from jax-resnet==0.0.4) (0.6.2)
Collecting jax (from jax-resnet==0.0.4)
  Using cached jax-0.6.2-py3-none-any.whl.metadata (13 kB)
Requirement already satisfied: msgpack in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from flax) (1.1.1)
Requirement already satisfied: orbax-checkpoint in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from flax) (0.11.25)
Requirement already satisfied: tensorstore in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from flax) (0.1.77)
Requirement already satisfied: rich>=11.1 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from flax) (14.1.0)
Requirement already satisfied: typing_extensions>=4.2 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from flax) (4.15.0)
Requirement already satisfied: PyYAML>=5.4.1 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from flax) (6.0.3)
Requirement already satisfied: treescope>=0.1.7 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from flax) (0.1.10)
Requirement already satisfied: absl-py>=0.7.1 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from optax) (2.3.1)
Requirement already satisfied: chex>=0.1.87 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from optax) (0.1.90)
Requirement already satisfied: numpy>=1.18.0 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from optax) (1.26.4)
Requirement already satisfied: dm-tree in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from tensorflow-datasets) (0.1.9)
Requirement already satisfied: etils>=1.6.0 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from etils[edc,enp,epath,epy,etree]>=1.6.0; python_version < "3.11"->tensorflow-datasets) (1.13.0)
Requirement already satisfied: immutabledict in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from tensorflow-datasets) (4.2.1)
Requirement already satisfied: promise in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from tensorflow-datasets) (2.3)
Requirement already satisfied: protobuf>=3.20 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from tensorflow-datasets) (4.21.12)
Requirement already satisfied: psutil in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from tensorflow-datasets) (7.1.0)
Requirement already satisfied: pyarrow in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from tensorflow-datasets) (21.0.0)
Requirement already satisfied: requests>=2.19.0 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from tensorflow-datasets) (2.32.5)
Requirement already satisfied: simple_parsing in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from tensorflow-datasets) (0.1.7)
Requirement already satisfied: tensorflow-metadata in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from tensorflow-datasets) (1.17.2)
Requirement already satisfied: termcolor in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from tensorflow-datasets) (3.1.0)
Requirement already satisfied: toml in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from tensorflow-datasets) (0.10.2)
Requirement already satisfied: tqdm in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from tensorflow-datasets) (4.67.1)
Requirement already satisfied: wrapt in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from tensorflow-datasets) (1.17.3)
Requirement already satisfied: toolz>=0.9.0 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from chex>=0.1.87->optax) (1.0.0)
Requirement already satisfied: fsspec in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from etils[edc,enp,epath,epy,etree]>=1.6.0; python_version < "3.11"->tensorflow-datasets) (2025.9.0)
Requirement already satisfied: importlib_resources in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from etils[edc,enp,epath,epy,etree]>=1.6.0; python_version < "3.11"->tensorflow-datasets) (6.5.2)
Requirement already satisfied: zipp in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from etils[edc,enp,epath,epy,etree]>=1.6.0; python_version < "3.11"->tensorflow-datasets) (3.23.0)
Requirement already satisfied: einops in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from etils[edc,enp,epath,epy,etree]>=1.6.0; python_version < "3.11"->tensorflow-datasets) (0.8.1)
Requirement already satisfied: ml_dtypes>=0.5.0 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from jax->jax-resnet==0.0.4) (0.5.3)
Requirement already satisfied: opt_einsum in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from jax->jax-resnet==0.0.4) (3.4.0)
Requirement already satisfied: scipy>=1.12 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from jax->jax-resnet==0.0.4) (1.15.3)
Requirement already satisfied: charset_normalizer<4,>=2 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from requests>=2.19.0->tensorflow-datasets) (3.4.3)
Requirement already satisfied: idna<4,>=2.5 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from requests>=2.19.0->tensorflow-datasets) (3.10)
Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from requests>=2.19.0->tensorflow-datasets) (2.5.0)
Requirement already satisfied: certifi>=2017.4.17 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from requests>=2.19.0->tensorflow-datasets) (2025.8.3)
Requirement already satisfied: markdown-it-py>=2.2.0 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from rich>=11.1->flax) (4.0.0)
Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from rich>=11.1->flax) (2.19.2)
Requirement already satisfied: mdurl~=0.1 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from markdown-it-py>=2.2.0->rich>=11.1->flax) (0.1.2)
Requirement already satisfied: attrs>=18.2.0 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from dm-tree->tensorflow-datasets) (25.3.0)
Requirement already satisfied: nest_asyncio in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from orbax-checkpoint->flax) (1.6.0)
Requirement already satisfied: aiofiles in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from orbax-checkpoint->flax) (24.1.0)
Requirement already satisfied: humanize in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from orbax-checkpoint->flax) (4.13.0)
Requirement already satisfied: simplejson>=3.16.0 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from orbax-checkpoint->flax) (3.20.2)
Requirement already satisfied: six in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from promise->tensorflow-datasets) (1.17.0)
Requirement already satisfied: docstring-parser<1.0,>=0.15 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from simple_parsing->tensorflow-datasets) (0.17.0)
Using cached jax-0.6.2-py3-none-any.whl (2.7 MB)
Installing collected packages: jax
  Attempting uninstall: jax
    Found existing installation: jax 0.2.22
    Uninstalling jax-0.2.22:
      Successfully uninstalled jax-0.2.22
Successfully installed jax-0.6.2
In [21]:
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1"

import tensorflow as tf

tf.get_logger().setLevel("WARNING")
tf.config.experimental.set_visible_devices([], "GPU")

import json
from functools import partial
from pathlib import Path

import flax
import flax.core
import flax.linen as nn
import jax
import jax.numpy as jnp
import jax_resnet
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_datasets as tfds
import torch
import tqdm
from IPython.display import Markdown, display

Utils¶

In [4]:
CLASS_NAMES = (
    "tench",
    "English springer",
    "cassette player",
    "chain saw",
    "church",
    "French horn",
    "garbage truck",
    "gas pump",
    "golf ball",
    "parachute",
)

BIAS_COLORS = [
    [0.1215, 0.4666, 0.7058],
    [0.5490, 0.3372, 0.2941],
    [1.0000, 0.4980, 0.0549],
    [0.1725, 0.6274, 0.1725],
    [0.8392, 0.1529, 0.1568],
    [0.5803, 0.4039, 0.7411],
    [0.8901, 0.4666, 0.7607],
    [0.7372, 0.7411, 0.1333],
    [0.0901, 0.7450, 0.8117],
    [0.4980, 0.4980, 0.4980],
]

RED = np.array([1.0, 0, 0])
BLUE = np.array([0, 0, 1.0])


def create_dataset(data_dir: str, batch_size: int):
    ds_builder = tfds.builder("imagenette/320px-v2", data_dir=data_dir)
    ds_builder.download_and_prepare()

    ds_val = ds_builder.as_dataset("validation", as_supervised=True)
    ds_val = ds_val.map(resize)
    ds_val = ds_val.map(add_bias_pixel)
    ds_val = ds_val.batch(batch_size)
    ds_val = tfds.as_numpy(ds_val)

    return ds_val


def resize(image, label):
    image = tf.image.resize_with_pad(image, 224, 224)
    return image / 255.0, label


def add_bias_pixel(image, label):
    hw_ = tf.reduce_sum(image, axis=[0, 1])
    hw_ = tf.cast(hw_, tf.int32) % 30 + 140
    h = hw_[0]
    w = hw_[1]
    color = tf.constant(BIAS_COLORS)[label]
    mask = tf.meshgrid(tf.range(224), tf.range(224), indexing="ij")
    mask = (
        (mask[0] % 12 != tf.cast(label, tf.int32) + 1)
        & (mask[0] > h)
        & (mask[0] < h + 12)
        & (mask[1] % 5 < 2)
        & (mask[1] > w)
        & (mask[1] < w + 30)
    )
    image = tf.where(mask[:, :, None], color, image)
    return image, label


def load_checkpoint(path):
    @jax.jit
    def logits_fn(variables, img):
        # img: [H, W, C], float32 in range [0, 1]
        assert img.ndim == 3
        img = normalize_for_resnet(img)
        logits = model.apply(variables, img[None, ...], mutable=False)[0]
        return logits.max(), logits

    path = Path(path)
    args = json.loads(Path.read_text(path / "args.json"))
    variables_path = path / "variables.npy"

    model = getattr(jax_resnet.resnet, f"ResNet{args['resnet_size']}")(n_classes=10)
    variables = model.init(jax.random.PRNGKey(0), jnp.zeros((1, 224, 224, 3)))
    variables = flax.serialization.from_bytes(variables, variables_path.read_bytes())

    return logits_fn, variables


def normalize_for_resnet(images):
    # images: [..., H, W, 3], float32, range [0, 1]
    mean = jnp.array([0.485, 0.456, 0.406])
    std = jnp.array([0.229, 0.224, 0.225])
    return (images - mean) / std


def imagenet_to_imagenette_logits(logits):
    """Select the 10 imagenette classes from the 1000 imagenet classes."""
    return logits[..., [0, 217, 482, 491, 497, 566, 569, 571, 574, 701]]


def show_images(images, labels=None, logits=None, ncols=4, width_one_img_inch=3.0):
    B, H, W, *_ = images.shape
    nrows = int(np.ceil(B / ncols))
    fig, axs = plt.subplots(
        nrows,
        ncols,
        figsize=width_one_img_inch * np.array([1, H / W]) * np.array([ncols, nrows]),
        sharex=True,
        sharey=True,
        squeeze=False,
        facecolor="white",
    )
    for b in range(B):
        ax = axs.flat[b]
        ax.imshow(images[b])
        if labels is not None:
            ax.set_title(CLASS_NAMES[labels[b]])
        if logits is not None:
            pred = logits[b].argmax()
            prob = jax.nn.softmax(logits[b])[pred]
            color = (
                "blue" if labels is None else ("green" if labels[b] == pred else "red")
            )
            p = mpl.patches.Patch(color=color, label=f"{prob:.2%} {CLASS_NAMES[pred]}")
            ax.legend(handles=[p])
    fig.tight_layout()
    display(fig)
    plt.close(fig)


@jax.jit
def blend(a, b, alpha: float):
    return (1 - alpha) * a + alpha * b

Metrics¶

These tables summarize the hyperparameters used to train the models and their performance.

Accuracy and loss are reported for two slightly different versions of the validation set: one that contains a clear source of bias and one that doesn't. If we only had access to the biased dataset and we did not know about the bias, we might be tempted to choose the first model, which achieves a much higher accuracy than the second.

In [5]:
paths = [Path('output/biased'), Path('output/unbiased')]

df_args = (
    pd.DataFrame([json.loads(Path.read_text(p / "args.json")) for p in paths])
    .drop(columns="output")
    .set_index("run_id")
    .sort_index()
)
display(df_args)

df_test = pd.DataFrame(
    [
        {"run_id": p.parent.name, **json.loads(line)}
        for p in paths
        for line in Path.read_text(p / "test.json").splitlines()
    ],
)

display(
    df_test.pivot_table(
        index="run_id", columns="bias_pixel", values=["accuracy", "loss"]
    )
    .sort_index()
    .style.format("{:.3f}")
    .format("{:.1%}", subset="accuracy")
)
bias_pixel resnet_size epochs seed learning_rate weight_decay batch_size
run_id
biased True 18 10 5807 0.001 0.0001 64
unbiased False 18 10 5807 0.001 0.0001 64
  accuracy loss
bias_pixel False True False True
run_id        
output 48.7% 85.2% 2.774 0.551

Model comparison¶

Task 1¶

Reimplement one of the explanation methods from the previous notebook and use it to visualize the most important regions for the first few batches of images.

  • Can you spot the model that was trained on biased data?
  • Which explanation method did you choose? Can you motivate your choice? Did you try others to see what worked best?
  • Can you summarize the explanation method and suggest why it works best here?

Add your comments below:

  • Yes, the Model B is trained on biased data. It incorrectly predicts a chain saw as a cassette player in Batch 0 and a gas pump as Church in Batch 2, showing that it relies on spurious features rather than the actual object.
  • I have chosen the Integrated Gradients explanation method as it had the lowest deletion score from previous work and was concluded to the best method.
  • I have also implemented Occlusion, Grad_X_Input and Grad_CAM method to compare the results. I see that Grad_CAM results in the same predictions for both the models. This reveals the shortcut learning of Grad_CAM as it activates similar feature map regions for learning the object and the artifact. Grad_X_Input method on the other hand predicts identically to Integrated Gradients suggesting it captures similar pixel-level relevance but with more noise. Occlusion method also aligned with Integrated Gradients. Even though it does not depend on gradients, Occlusion method was able to expose bias fairly well as it directly tests the causal impact of masking image regions.
  • Integrated Gradients explanation method highlights pixels that consistently increase the prediction confidence as they are added to the image. It is the best at exposing bias, as it directly shows pixel-level causal attributions, making it clear whether the model relies on the object or artifact.

Note: explanation_fn will be called with logits_fn, variables, and images. Extra parameters can be put in kwargs and partial will take care of them.

In [33]:
@partial(jax.jit, static_argnames=["steps"])
def prepare_integrated_gradients(img, steps: int):
    assert img.ndim == 3
    return img[None, :, :, :] * jnp.linspace(1, 0, num=steps)[:, None, None, None]

@jax.jit
def normalize_max(x):
    """Normalize a vector between -1 and 1."""
    res = x / jnp.abs(x).max()
    res = jnp.clip(res, a_min=-1, a_max=1)
    return res


def integrated_grad_fn(logits_fn, variables, img, steps: int):
    H, W, _ = img.shape
    # model's predicted class
    _, logits_orig = logits_fn(variables, img)
    idx = logits_orig.argmax()
    baseline = jnp.zeros_like(img) 
    images = prepare_integrated_gradients(img, steps).reshape(-1, H, W, 3)
    _, logits = jax.vmap(logits_fn, (None, 0))(variables, images)
    # function to call grad on idx-th element of logits
    def grads_idx_fn(variables, img_):
        logit_max, logit = logits_fn(variables, img_)
        val = logit[idx]
        return val, logit_max
    value_and_grad_fn = jax.value_and_grad(grads_idx_fn, argnums=1, has_aux=True)
    (_,_), grads = jax.vmap(lambda im: value_and_grad_fn(variables, im), in_axes=0)(images)
    avg_grads = grads.mean(axis=0)
    ig = (img - baseline) * avg_grads
    heat = jnp.linalg.norm(ig, axis=-1)
    grads = normalize_max(heat)
    # logits: [num_classes]
    # grads:  [H, W]
    return logits_orig, grads
    
def explanation_fn(logits_fn, variables, img):
    H, W, _ = img.shape
    logits, attrib = integrated_grad_fn(logits_fn, variables, img, 25)
    # logits: [num_classes]
    # attrib: [H, W]
    return logits, attrib


kwargs = {}
explanation_fn = partial(explanation_fn, **kwargs)
explanation_fn = jax.vmap(explanation_fn, in_axes=(None, None, 0))
explanation_fn = jax.jit(explanation_fn, static_argnames=["logits_fn"])
In [ ]:
### Integrated Gradients method 
ds_val = create_dataset(".", batch_size=4)

logits_fn_a, variables_a = load_checkpoint("output/biased")
logits_fn_b, variables_b = load_checkpoint("output/unbiased")

for batch_idx, (images, labels) in enumerate(ds_val):
    display(Markdown(f"## Batch {batch_idx}"))

    display(Markdown(f"### Model A"))
    logits, relevance = explanation_fn(logits_fn_a, variables_a, images)
    show_images(
        blend(images, RED, relevance.clip(min=0)[..., None]),
        labels,
        logits,
    )

    display(Markdown(f"### Model B"))
    logits, relevance = explanation_fn(logits_fn_b, variables_b, images)
    show_images(
        blend(images, RED, relevance.clip(min=0)[..., None]),
        labels,
        logits,
    )

    if batch_idx >= 2:
        break

Batch 0¶

Model A¶

No description has been provided for this image

Model B¶

No description has been provided for this image

Batch 1¶

Model A¶

No description has been provided for this image

Model B¶

No description has been provided for this image

Batch 2¶

Model A¶

No description has been provided for this image

Model B¶

No description has been provided for this image
In [31]:
def prepare_occlusions(img, steps: int):
    H, W, _ = img.shape
    imgs = jnp.tile(img, (steps, steps, 1, 1, 1))
    for i in range(0, steps):
        for j in range(0, steps):
            imgs = imgs.at[i, j, int(i*H/steps):int((i+1)*H/steps), int(j*W/steps):int((j+1)*W/steps), :].set(0)
    # imgs: [steps, steps, H, W, 3]
    return imgs

def occlusion_fn(logits_fn, variables, img, steps: int):
    H, W, _ = img.shape
    _, logits_orig = logits_fn(variables, img)
    probs = nn.softmax(logits_orig)
    idx = logits_orig.argmax()
    imgs = prepare_occlusions(img, steps)
    logits_occ_fn = jax.vmap(                                   
        jax.vmap(logits_fn, (None,0)),
        (None,0)
    )
    _, logits_occ = logits_occ_fn(variables, imgs)
    probs_occ = nn.softmax(logits_occ, axis=-1)
    relevance = probs[idx] - probs_occ[..., idx]
    relevance = jax.image.resize(relevance, (H, W), method="bilinear")
    attrib = normalize_max(relevance)
    # logits_orig: [num_classes]
    # attrib:      [H, W]
    return logits_orig, attrib


def explanation_fn(logits_fn, variables, img):
    H, W, _ = img.shape
    steps = 15
    logits, attrib = occlusion_fn(logits_fn, variables, img, steps)
    # logits: [num_classes]
    # attrib: [H, W]
    return logits, attrib


kwargs = {}
explanation_fn = partial(explanation_fn, **kwargs)
explanation_fn = jax.vmap(explanation_fn, in_axes=(None, None, 0))
explanation_fn = jax.jit(explanation_fn, static_argnames=["logits_fn"])
In [32]:
### Occlusion method

ds_val = create_dataset(".", batch_size=4)

logits_fn_a, variables_a = load_checkpoint("output/biased")
logits_fn_b, variables_b = load_checkpoint("output/unbiased")

for batch_idx, (images, labels) in enumerate(ds_val):
    display(Markdown(f"## Batch {batch_idx}"))

    display(Markdown(f"### Model A"))
    logits, relevance = explanation_fn(logits_fn_a, variables_a, images)
    show_images(
        blend(images, RED, relevance.clip(min=0)[..., None]),
        labels,
        logits,
    )

    display(Markdown(f"### Model B"))
    logits, relevance = explanation_fn(logits_fn_b, variables_b, images)
    show_images(
        blend(images, RED, relevance.clip(min=0)[..., None]),
        labels,
        logits,
    )

    if batch_idx >= 2:
        break

Batch 0¶

Model A¶

No description has been provided for this image

Model B¶

No description has been provided for this image

Batch 1¶

Model A¶

No description has been provided for this image

Model B¶

No description has been provided for this image

Batch 2¶

Model A¶

No description has been provided for this image

Model B¶

No description has been provided for this image
In [34]:
def grad_x_input_fn(logits_fn, variables, img):
    H, W, _ = img.shape
    logits_vg_fn = jax.value_and_grad(logits_fn, argnums=1, has_aux=True)
    (_, logits), grads = logits_vg_fn(variables, img)
    grads_x = img * grads
    heat_x = jnp.linalg.norm(grads_x, axis=-1) 
    grad = normalize_max(heat_x)
    # logits: [num_classes]
    # grad:   [H, W]
    return logits, grad


def explanation_fn(logits_fn, variables, img):
    H, W, _ = img.shape
    logits, attrib = grad_x_input_fn(logits_fn, variables, img)
    # logits: [num_classes]
    # attrib: [H, W]
    return logits, attrib


kwargs = {}
explanation_fn = partial(explanation_fn, **kwargs)
explanation_fn = jax.vmap(explanation_fn, in_axes=(None, None, 0))
explanation_fn = jax.jit(explanation_fn, static_argnames=["logits_fn"])
In [35]:
### Grad_X_Input method

ds_val = create_dataset(".", batch_size=4)

logits_fn_a, variables_a = load_checkpoint("output/biased")
logits_fn_b, variables_b = load_checkpoint("output/unbiased")

for batch_idx, (images, labels) in enumerate(ds_val):
    display(Markdown(f"## Batch {batch_idx}"))

    display(Markdown(f"### Model A"))
    logits, relevance = explanation_fn(logits_fn_a, variables_a, images)
    show_images(
        blend(images, RED, relevance.clip(min=0)[..., None]),
        labels,
        logits,
    )

    display(Markdown(f"### Model B"))
    logits, relevance = explanation_fn(logits_fn_b, variables_b, images)
    show_images(
        blend(images, RED, relevance.clip(min=0)[..., None]),
        labels,
        logits,
    )

    if batch_idx >= 2:
        break

Batch 0¶

Model A¶

No description has been provided for this image

Model B¶

No description has been provided for this image

Batch 1¶

Model A¶

No description has been provided for this image

Model B¶

No description has been provided for this image

Batch 2¶

Model A¶

No description has been provided for this image

Model B¶

No description has been provided for this image
In [17]:
def grad_cam_fn(fns, variables, img):
    H, W, _ = img.shape
    backbone_fn = fns["backbone"]
    gap_classifier_fn  = fns["gap_cls"]
    backbone_vars = variables["backbone"]
    gap_classifier_vars  = variables["gap_cls"]

    # apply image through backbone 
    features = backbone_fn(backbone_vars, img)
    _, logits = gap_classifier_fn(gap_classifier_vars, features)
    # fix the target class c
    c = jnp.argmax(logits)   

    # scalar logit function for class c
    def class_logit_fn(vars_, feats_):
        _, logits = gap_classifier_fn(vars_, feats_)
         # scalar Y^c
        return logits[c]

    #  gradients wrt features for class c
    vgf = jax.value_and_grad(class_logit_fn, argnums=1, has_aux=False)
    _, grads = vgf(gap_classifier_vars, features) 
    alpha = grads.mean(axis=(0,1))
    relevance = jnp.einsum("hwc,c->hw", features, alpha)
    relevance = jnp.maximum(relevance, 0)
    #print("relevance:", relevance.shape)
    # resize to input image size
    relevance_resized = jax.image.resize(
        relevance, (H, W), method="bilinear"
    )
    relevance_resized = normalize_max(relevance_resized)
    #print("relevance_resized:", relevance_resized.shape)
    # logits: [num_classes]
    # grad:   [H, W]
    return logits, relevance_resized


def load_resnet_for_grad_cam(size):
    @jax.jit
    def backbone_fn(variables, img):
        # img:   [H, W, C], float32 in range [0, 1]
        # feats: [h, w, c], float32
        img = normalize_for_resnet(img)
        feats = backbone.apply(variables, img[None, ...], mutable=False)[0]
        return feats

    @jax.jit
    def gap_classifier_fn(variables, feats):
        # feats:  [h, w, c], float32
        # logit:  float32
        # logits: [10], float32
        logits = gap_classifier.apply(variables, feats[None, ...], mutable=False)[0]
        logits = imagenet_to_imagenette_logits(logits)
        return logits.max(), logits

    ResNet, variables = jax_resnet.pretrained_resnet(size)
    model = ResNet()

    backbone = nn.Sequential(model.layers[:-2])
    backbone_vars = jax_resnet.slice_variables(variables, start=0, end=-2)
    gap_classifier = nn.Sequential(model.layers[-2:])
    gap_classifier_vars = jax_resnet.slice_variables(variables, start=len(model.layers) - 2, end=None)
    return (
        flax.core.freeze({"backbone": backbone_fn, "gap_cls": gap_classifier_fn}),
        flax.core.freeze({"backbone": backbone_vars, "gap_cls": gap_classifier_vars}),
    )

def explanation_fn(logits_fn, variables, img):
    H, W, _ = img.shape
    fns, variables = load_resnet_for_grad_cam(size=18)
    logits, attrib = grad_cam_fn(fns, variables, img)
    # logits: [num_classes]
    # attrib: [H, W]
    return logits, attrib


kwargs = {}
explanation_fn = partial(explanation_fn, **kwargs)
explanation_fn = jax.vmap(explanation_fn, in_axes=(None, None, 0))
explanation_fn = jax.jit(explanation_fn, static_argnames=["logits_fn"])
In [22]:
### Grad-CAM method

ds_val = create_dataset(".", batch_size=4)

logits_fn_a, variables_a = load_checkpoint("output/biased")
logits_fn_b, variables_b = load_checkpoint("output/unbiased")

for batch_idx, (images, labels) in enumerate(ds_val):
    display(Markdown(f"## Batch {batch_idx}"))

    display(Markdown(f"### Model A"))
    logits, relevance = explanation_fn(logits_fn_a, variables_a, images)
    show_images(
        blend(images, RED, relevance.clip(min=0)[..., None]),
        labels,
        logits,
    )

    display(Markdown(f"### Model B"))
    logits, relevance = explanation_fn(logits_fn_b, variables_b, images)
    show_images(
        blend(images, RED, relevance.clip(min=0)[..., None]),
        labels,
        logits,
    )

    if batch_idx >= 2:
        break

Batch 0¶

Model A¶

Using cache found in /Users/silpasoninallacheruvu/.cache/torch/hub/pytorch_vision_v0.10.0
No description has been provided for this image

Model B¶

Using cache found in /Users/silpasoninallacheruvu/.cache/torch/hub/pytorch_vision_v0.10.0
No description has been provided for this image

Batch 1¶

Model A¶

No description has been provided for this image

Model B¶

No description has been provided for this image

Batch 2¶

Model A¶

No description has been provided for this image

Model B¶

No description has been provided for this image

Task 2¶

How long did it take you to complete this practical? This information is valuable to us to balance the difficulty of different practicals. 1.5 hour

In [ ]: